#!/usr/bin/env python
#-*- coding:utf-8 -*-

from Crypto.Cipher import AES
# def aes_cipher(key,mode = AES.MODE_CBC,iv = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16'):
def aes_cipher(key, mode=AES.MODE_CBC, iv='1234567890123456'):
    '''
    Build an AES-128 cipher object from a key of arbitrary length.

    The key is normalised to exactly 16 bytes (128 bits): a longer key is
    truncated to its first 16 bytes, a shorter key is right-padded with
    NUL bytes.  Text keys and IVs are encoded as UTF-8 first, so lengths
    are measured in bytes and bytes-only AES backends accept them.

    key  -- key material: bytes, bytearray, memoryview or text.
    mode -- one of the AES.MODE_* constants (CBC by default).
    iv   -- initialisation vector, passed to AES.new as its third
            positional argument; ignored for ECB, which has no IV.

    Returns the object created by AES.new.

    NOTE(review): the default IV is a fixed constant.  Reusing one IV with
    the same key under CBC reveals when two messages share a prefix;
    callers that care about confidentiality should pass a fresh random IV
    and store it alongside the ciphertext.
    '''
    key_size = 16

    # Only text needs encoding.  The hasattr() test skips bytes-like
    # objects and None, so those reach AES.new untouched.
    if not isinstance(key, bytes) and hasattr(key, 'encode'):
        key = key.encode('utf-8')
    if not isinstance(iv, bytes) and hasattr(iv, 'encode'):
        iv = iv.encode('utf-8')

    if len(key) >= key_size:
        real_key = key[:key_size]
    else:
        # bytearray() accepts any bytes-like object, so bytearray and
        # memoryview keys are padded too instead of ending up as None.
        real_key = bytes(bytearray(key)) + b'\x00' * (key_size - len(key))

    if mode == AES.MODE_ECB:
        # ECB has no IV; PyCryptodome raises TypeError when one is given.
        return AES.new(real_key, mode)
    return AES.new(real_key, mode, iv)

def aes_encrypt(origin, key, mode=AES.MODE_CBC, iv='1234567890123456'):
    '''
    Encrypt origin with AES after zero-padding it to a whole block.

    origin -- plaintext: bytes, bytearray, memoryview or text.  Text is
              encoded as UTF-8 before padding so that the padding length
              is computed from the byte length, not the character count.
    key, mode, iv -- forwarded unchanged to aes_cipher.

    Returns whatever cipher.encrypt returns for the padded plaintext.

    The padding consists of NUL bytes and is added only when the length is
    not already a multiple of AES.block_size.  It is not self-describing:
    plaintext that genuinely ends in NUL bytes cannot be told apart from
    padding after decryption.
    '''
    cipher = aes_cipher(key, mode, iv)

    # Only text has to be encoded; bytes-like input is left as it is.
    if not isinstance(origin, bytes) and hasattr(origin, 'encode'):
        origin = origin.encode('utf-8')

    remainder = len(origin) % AES.block_size
    if remainder:
        # bytearray() accepts any bytes-like object, so unaligned bytearray
        # or memoryview input is padded instead of hitting an unbound name.
        padding = b'\x00' * (AES.block_size - remainder)
        origin = bytes(bytearray(origin)) + padding
    return cipher.encrypt(origin)

def aes_decrypt(encrypted, key, mode=AES.MODE_CBC, iv='1234567890123456'):
    '''
    Decrypt encrypted with AES and return the raw result.

    encrypted -- ciphertext: bytes, bytearray, memoryview, or a str
                 (str is handed to the cipher as given).
    key, mode, iv -- forwarded unchanged to aes_cipher; they must match
                     the values used for encryption.

    Returns whatever cipher.decrypt returns.  Nothing is stripped from the
    result, so NUL padding added by aes_encrypt is still present and the
    caller must remove it.

    Ciphertext whose length is not a multiple of AES.block_size is
    extended with NUL bytes before decryption (best-effort behaviour kept
    from the original implementation).
    NOTE(review): such input is almost certainly truncated or corrupt, and
    the block containing the added bytes will not decrypt to meaningful
    data.
    '''
    cipher = aes_cipher(key, mode, iv)

    remainder = len(encrypted) % AES.block_size
    if remainder == 0:
        return cipher.decrypt(encrypted)

    pad_len = AES.block_size - remainder
    if isinstance(encrypted, str):
        aligned = encrypted + '\x00' * pad_len
    else:
        # bytearray() accepts any bytes-like object and raises TypeError
        # for anything else, replacing the former UnboundLocalError.
        aligned = bytes(bytearray(encrypted)) + b'\x00' * pad_len
    return cipher.decrypt(aligned)